#!/usr/bin/env python3
"""
Phase 4g: Two analyses
1. Multiple testing correction across all hypothesis tests
2. CCA commodity vs non-commodity decomposition (is the CCA tipping point driven by oil?)
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path
from scipy import stats

FOLLOWUP_DIR = Path("/mnt/c/demographics_capital_flows/multilateral/followup")
PROJECT_DIR = FOLLOWUP_DIR.parent
# Prepend the followup directory, then the project root, to the import path so
# that a `src` package under the followup directory (if present) shadows one
# under the project root for the `src.*` imports that follow.
sys.path[0:0] = [str(FOLLOWUP_DIR), str(PROJECT_DIR)]

from src.model import PanelGLS
from src.macro import EBA_COUNTRIES, SSA_COUNTRIES, filter_eba_sample

OUTPUT_DIR = FOLLOWUP_DIR / "output" / "tables"
PROCESSED_DIR = FOLLOWUP_DIR / "data" / "processed"

# Load the processed panel and keep the estimation window: rows with a
# current-account value and a year in 1986-2024 (both ends inclusive).
panel = pd.read_csv(PROCESSED_DIR / "full_panel.csv")
est = panel.loc[
    panel['ca_gdp'].notna() & panel['year'].between(1986, 2024)
].copy()
# Country-sample filter from src.macro, called with extended=True and
# expansion=True (the meaning of those flags is defined there).
full_sample = filter_eba_sample(est, extended=True, expansion=True)

# Demographic factors and the baseline control set used to build regressors.
demo_vars = ['Z_1', 'Z_2', 'Z_3']
baseline_controls = [
    'fiscal_bal_gdp',
    'kaopen',
    'expected_growth',
    'nfa_gdp_lag',
    'log_rel_opw',
    'health_exp_gdp',
]


def run_gls(df, dep_var, indep_vars):
    """Fit a PanelGLS regression on the complete cases of ``df``.

    Rows missing ``dep_var`` or any of ``indep_vars`` are dropped before
    fitting; the ``iso3`` and ``year`` columns supply the panel identifiers.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel containing ``iso3``, ``year``, ``dep_var`` and ``indep_vars``.
    dep_var : str
        Dependent-variable column name.
    indep_vars : list
        Regressor column names, in design-matrix order.

    Returns
    -------
    dict
        ``r_squared``, ``n_obs`` and ``n_countries`` as reported by the fitted
        model, plus ``coefficients``, ``std_errors`` and ``p_values``, each a
        dict keyed by regressor name.
    """
    complete = df.dropna(subset=[dep_var] + indep_vars).copy()
    response = complete[dep_var].values
    design = complete[indep_vars].values

    model = PanelGLS()
    model.fit(response, design, complete['iso3'].values, complete['year'].values)

    def by_regressor(values):
        # Pair each estimate with its regressor name, preserving column order.
        return dict(zip(indep_vars, values))

    summary = {
        'r_squared': model.r_squared,
        'n_obs': model.n_obs,
        'n_countries': model.n_countries,
    }
    summary['coefficients'] = by_regressor(model.beta)
    summary['std_errors'] = by_regressor(model.se)
    summary['p_values'] = by_regressor(model.pvalues)
    return summary


# ═══════════════════════════════════════════════════════════════════════════
# 1. MULTIPLE TESTING CORRECTION
# ═══════════════════════════════════════════════════════════════════════════
print("=" * 70)
print("MULTIPLE TESTING CORRECTION")
print("=" * 70)

# Catalog of all distinct hypothesis tests conducted in the followup.
# Each test: (name, type, raw_p, n_obs, description)
# Type: "confirmatory" = replication/pre-specified; "exploratory" = post-hoc
# NOTE(review): p-values are hard-coded rather than recomputed here; entries of
# 0.0001 presumably stand for "p < 0.0001" — confirm against the source tables.
tests = [
    # Baseline model (confirmatory — replicating Koomen-Wicht on broader sample)
    ("Z₁ baseline", "confirmatory", 0.0009, 2730, "Z₁ in 140c baseline model"),
    ("Z₂ baseline", "confirmatory", 0.0003, 2730, "Z₂ in 140c baseline model"),
    ("Z₃ baseline", "confirmatory", 0.0001, 2730, "Z₃ in 140c baseline model"),

    # Extended model interactions (confirmatory — pre-specified in original paper)
    ("Z₁×KAOPEN", "confirmatory", 0.039, 1626, "Z₁×KAOPEN in extended model"),
    ("Z₂×KAOPEN", "confirmatory", 0.021, 1626, "Z₂×KAOPEN in extended model"),
    ("Z₃×KAOPEN", "confirmatory", 0.013, 1626, "Z₃×KAOPEN in extended model"),
    ("Joint Z×KAOPEN", "confirmatory", 0.001, 1626, "Joint F-test for Z×KAOPEN"),

    # Three-way interactions (exploratory)
    ("High-income Z×K joint=0", "exploratory", 0.280, 1626, "Joint F: high-income interactions zero"),
    ("Middle-income Z×K joint", "exploratory", 0.0001, 1626, "Joint F: middle-income interactions"),
    ("Low-income Z×K joint", "exploratory", 0.0001, 1626, "Joint F: low-income interactions"),

    # Pension model (exploratory)
    ("Joint Z×pension (EBA-49)", "exploratory", 0.038, 597, "Joint F: pension interactions on EBA-49"),
    ("Z₃×pension horse race", "exploratory", 0.043, 597, "Z₃×pension in horse race vs KAOPEN"),
    ("Z₁×pension (full 140)", "exploratory", 0.008, 750, "Z₁×pension on full sample"),
    ("Z₂×pension (full 140)", "exploratory", 0.005, 750, "Z₂×pension on full sample"),
    ("Z₃×pension (full 140)", "exploratory", 0.003, 750, "Z₃×pension on full sample"),

    # Nonlinearity (exploratory — motivated by theory but not pre-specified)
    ("NFA² quadratic", "exploratory", 0.033, 2730, "NFA-squared in baseline"),
    ("NFA creditor", "exploratory", 0.080, 2730, "Creditor (NFA>0) coefficient"),
    ("LE² quadratic", "exploratory", 0.059, 2730, "Life expectancy squared"),

    # Interest rate channel (exploratory)
    ("Bond yield diff", "exploratory", 0.031, 679, "Real 10y bond yield differential"),
    ("Term spread", "exploratory", 0.001, 663, "Term spread coefficient"),
    ("Short rate diff", "exploratory", 0.590, 679, "Real short rate differential (null)"),
    ("Two-stage S2", "exploratory", 0.019, 679, "Carvalho two-stage, Stage 2"),

    # Horse race: KAOPEN vs GDP (exploratory)
    ("KAOPEN joint in horse race", "exploratory", 0.0001, 1824, "KAOPEN interactions joint in HR"),
    ("GDP joint in horse race", "exploratory", 0.0001, 1824, "GDP interactions joint in HR"),

    # Savings/Investment channels (exploratory)
    ("Savings channel", "exploratory", 0.001, 2730, "Demographics → savings ΔR²"),
    ("Investment channel", "exploratory", 0.001, 2730, "Demographics → investment ΔR²"),
]

# Separate confirmatory and exploratory tests (the tuples are kept as-is).
confirmatory = [row for row in tests if row[1] == "confirmatory"]
exploratory = [row for row in tests if row[1] == "exploratory"]

print(f"\nTotal tests: {len(tests)}")
print(f"  Confirmatory: {len(confirmatory)}")
print(f"  Exploratory: {len(exploratory)}")

# Family-wise error rate (Bonferroni, Holm) and FDR level (BH). Every
# procedure below rejects when p <= its threshold, the textbook convention.
# Using one convention keeps the procedures nested (Bonferroni ⊆ Holm ⊆ BH),
# which the "not after any correction" report at the end relies on.
ALPHA = 0.05
n_tests_all = len(tests)

# Bonferroni correction
print(f"\n--- Bonferroni Correction (all {n_tests_all} tests) ---")
bonferroni_alpha = ALPHA / n_tests_all
print(f"  Adjusted alpha: {bonferroni_alpha:.4f}")

# Rank all tests by raw p-value (ties broken by test name); Holm and BH both
# operate on this ordering.
all_p = sorted((p, n, t, no, d) for n, t, p, no, d in tests)

# Holm-Bonferroni (step-down, less conservative than Bonferroni)
print(f"\n--- Holm-Bonferroni Step-Down (all tests) ---")
print(f"  {'Rank':>4s} {'Test':<35s} {'Type':<14s} {'Raw p':>8s} {'Holm α':>8s} {'Sig?':>5s}")
print("  " + "-" * 78)

holm_results = []
holm_rejecting = True
for rank, (p, name, typ, nobs, desc) in enumerate(all_p, 1):
    holm_alpha = ALPHA / (n_tests_all - rank + 1)
    # Step-down rule: the first hypothesis that fails its threshold stops the
    # procedure, so every later (larger-p) test is retained even if it would
    # clear its own, looser threshold.
    holm_rejecting = holm_rejecting and p <= holm_alpha
    sig = "YES" if holm_rejecting else "no"
    print(f"  {rank:4d} {name:<35s} {typ:<14s} {p:8.4f} {holm_alpha:8.4f} {sig:>5s}")
    holm_results.append({
        'rank': rank,
        'test': name,
        'type': typ,
        'raw_p': p,
        'holm_alpha': holm_alpha,
        'holm_significant': holm_rejecting,
        'bonferroni_significant': p <= bonferroni_alpha,
        'n_obs': nobs,
        'description': desc,
    })

# Benjamini-Hochberg FDR correction (step-up): find the largest rank k with
# p_(k) <= (k / m) * q, then reject every hypothesis ranked 1..k.
print(f"\n--- Benjamini-Hochberg FDR (q=0.05, all tests) ---")
print(f"  {'Rank':>4s} {'Test':<35s} {'Raw p':>8s} {'BH crit':>8s} {'Sig?':>5s}")
print("  " + "-" * 64)

bh_max_sig_rank = 0
for rank, (p, name, typ, nobs, desc) in enumerate(all_p, 1):
    if p <= (rank / n_tests_all) * ALPHA:
        bh_max_sig_rank = rank

for rank, (p, name, typ, nobs, desc) in enumerate(all_p, 1):
    bh_critical = (rank / n_tests_all) * ALPHA
    bh_sig = rank <= bh_max_sig_rank
    sig = "YES" if bh_sig else "no"
    print(f"  {rank:4d} {name:<35s} {p:8.4f} {bh_critical:8.4f} {sig:>5s}")
    # holm_results is in the same rank order as all_p.
    holm_results[rank - 1]['bh_critical'] = bh_critical
    holm_results[rank - 1]['bh_significant'] = bh_sig

# Summary
print(f"\n--- Summary ---")
hr_df = pd.DataFrame(holm_results)

n_raw = sum(1 for r in holm_results if r['raw_p'] < ALPHA)
n_bonf = sum(1 for r in holm_results if r['bonferroni_significant'])
n_holm = sum(1 for r in holm_results if r['holm_significant'])
n_bh = sum(1 for r in holm_results if r['bh_significant'])

print(f"  Significant at raw p<0.05:      {n_raw}/{len(tests)}")
print(f"  Significant after Bonferroni:    {n_bonf}/{len(tests)}")
print(f"  Significant after Holm:          {n_holm}/{len(tests)}")
print(f"  Significant after BH (FDR):      {n_bh}/{len(tests)}")

print(f"\n  Tests surviving Bonferroni (most conservative):")
for r in holm_results:
    if r['bonferroni_significant']:
        print(f"    {r['test']:<35s} p={r['raw_p']:.4f} ({r['type']})")

print(f"\n  Tests surviving Holm but not Bonferroni:")
for r in holm_results:
    if r['holm_significant'] and not r['bonferroni_significant']:
        print(f"    {r['test']:<35s} p={r['raw_p']:.4f} ({r['type']})")

print(f"\n  Tests surviving BH (FDR) but not Holm:")
for r in holm_results:
    if r['bh_significant'] and not r['holm_significant']:
        print(f"    {r['test']:<35s} p={r['raw_p']:.4f} ({r['type']})")

# BH is the least conservative of the three procedures here, so failing BH
# implies failing Holm and Bonferroni as well.
print(f"\n  Tests significant at raw p<0.05 but NOT after any correction:")
for r in holm_results:
    if r['raw_p'] < ALPHA and not r['bh_significant']:
        print(f"    {r['test']:<35s} p={r['raw_p']:.4f} ({r['type']})")

# Separate analysis for confirmatory only
print(f"\n\n--- Confirmatory Tests Only (Bonferroni with m={len(confirmatory)}) ---")
conf_alpha = ALPHA / len(confirmatory)
print(f"  Adjusted alpha: {conf_alpha:.4f}")
for name, typ, p, nobs, desc in confirmatory:
    sig = "YES" if p <= conf_alpha else "no"
    print(f"  {name:<35s} p={p:.4f}  sig? {sig}")

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
hr_df.to_csv(OUTPUT_DIR / 'multiple_testing_correction.csv', index=False)
print(f"\n  Saved to {OUTPUT_DIR / 'multiple_testing_correction.csv'}")


# ═══════════════════════════════════════════════════════════════════════════
# 2. CCA COMMODITY vs NON-COMMODITY DECOMPOSITION
# ═══════════════════════════════════════════════════════════════════════════
print("\n\n" + "=" * 70)
print("CCA DECOMPOSITION: COMMODITY vs NON-COMMODITY")
print("=" * 70)

# CCA countries with major commodity export dependence
cca_commodity = {'AZE', 'KAZ', 'RUS', 'TKM', 'UZB'}  # Oil/gas/mining dominant
cca_non_commodity = {'ARM', 'BLR', 'GEO', 'KGZ', 'MDA', 'MNG', 'TJK', 'UKR'}  # Remittances, services, agriculture
# NOTE(review): MNG sits in the non-commodity group even though the commodity
# group is defined to include mining exporters — confirm this classification.

cca_all = cca_commodity | cca_non_commodity

# Estimation sample for the baseline regression. run_gls drops exactly these
# incomplete rows, so base_sample is the sample the regressions below use.
base_vars = demo_vars + baseline_controls
base_sample = full_sample.dropna(subset=base_vars + ['ca_gdp']).copy()

# Verify which CCA countries actually enter estimation. A country can pass the
# sample filter yet lose every row to missing controls, so check both levels.
cca_in_panel = set(full_sample['iso3'].unique()) & cca_all
cca_in_sample = set(base_sample['iso3'].unique()) & cca_all
print(f"\nCCA countries in sample: {sorted(cca_in_sample)}")
print(f"  Commodity: {sorted(cca_in_sample & cca_commodity)}")
print(f"  Non-commodity: {sorted(cca_in_sample & cca_non_commodity)}")
cca_without_rows = cca_in_panel - cca_in_sample
if cca_without_rows:
    print(f"  In panel but without complete baseline rows: {sorted(cca_without_rows)}")

# Number of countries each drop-test actually removes from the estimation sample.
n_cca = len(cca_in_sample)
n_cca_com = len(cca_in_sample & cca_commodity)
n_cca_noncom = len(cca_in_sample & cca_non_commodity)


def _result_row(label, r):
    """Format one run_gls result as a row of the decomposition table.

    Columns: label, countries, observations, R², then coefficient and p-value
    for Z_1, Z_2 and Z_3 (coefficients shown to 2, 3 and 4 decimals).
    """
    cells = [
        f"{label:<35s}",
        f"{r['n_countries']:4d}",
        f"{r['n_obs']:6d}",
        f"{r['r_squared']:7.4f}",
    ]
    for var, coef_fmt in (('Z_1', '8.2f'), ('Z_2', '8.3f'), ('Z_3', '8.4f')):
        cells.append(format(r['coefficients'][var], coef_fmt))
        cells.append(format(r['p_values'][var], '8.4f'))
    return "  " + " ".join(cells)


def _z1_summary(r):
    """Return 'Z₁=<coef> (p=<pval>)' for a run_gls result."""
    return f"Z₁={r['coefficients']['Z_1']:.2f} (p={r['p_values']['Z_1']:.4f})"


# Test 1: Full sample (reference)
r_full = run_gls(base_sample, 'ca_gdp', base_vars)
print(f"\n  Full sample ({r_full['n_countries']} countries): {_z1_summary(r_full)}")

# Test 2: Drop ALL CCA
sub_no_cca = base_sample[~base_sample['iso3'].isin(cca_all)].copy()
r_no_cca = run_gls(sub_no_cca, 'ca_gdp', base_vars)
print(f"  Drop all CCA ({sub_no_cca['iso3'].nunique()} countries): {_z1_summary(r_no_cca)}")

# Test 3: Drop CCA commodity only
sub_no_cca_com = base_sample[~base_sample['iso3'].isin(cca_commodity)].copy()
r_no_cca_com = run_gls(sub_no_cca_com, 'ca_gdp', base_vars)
print(f"  Drop CCA commodity ({sub_no_cca_com['iso3'].nunique()} countries): {_z1_summary(r_no_cca_com)}")

# Test 4: Drop CCA non-commodity only
sub_no_cca_noncom = base_sample[~base_sample['iso3'].isin(cca_non_commodity)].copy()
r_no_cca_noncom = run_gls(sub_no_cca_noncom, 'ca_gdp', base_vars)
print(f"  Drop CCA non-commodity ({sub_no_cca_noncom['iso3'].nunique()} countries): {_z1_summary(r_no_cca_noncom)}")

# Full table
print(f"\n--- CCA Decomposition Results ---")
print(f"  {'Sample':<35s} {'N_c':>4s} {'N_obs':>6s} {'R²':>7s} {'Z₁':>8s} {'p(Z₁)':>8s} {'Z₂':>8s} {'p(Z₂)':>8s} {'Z₃':>8s} {'p(Z₃)':>8s}")
print("  " + "-" * 100)

results_list = [
    ("Full sample", r_full),
    (f"Drop all CCA ({n_cca})", r_no_cca),
    (f"Drop CCA commodity ({n_cca_com})", r_no_cca_com),
    (f"Drop CCA non-commodity ({n_cca_noncom})", r_no_cca_noncom),
]

for name, r in results_list:
    print(_result_row(name, r))

# Test 5: Add CCA group indicators as controls (a coarse stand-in for a
# commodity-export control).
# NOTE(review): these indicators are constant within each country; if PanelGLS
# includes country fixed effects they are collinear with them — confirm how
# PanelGLS handles such regressors before interpreting their coefficients.
base_sample['is_cca_commodity'] = base_sample['iso3'].isin(cca_commodity).astype(float)
base_sample['is_cca_non_commodity'] = base_sample['iso3'].isin(cca_non_commodity).astype(float)

# Test: baseline + CCA dummies (do demographics remain significant controlling for CCA-specific effects?)
cca_control_vars = base_vars + ['is_cca_commodity', 'is_cca_non_commodity']
r_cca_controlled = run_gls(base_sample, 'ca_gdp', cca_control_vars)
print(f"\n  With CCA dummies as controls:")
print(_result_row('+ CCA dummies', r_cca_controlled))
print(f"    CCA commodity dummy:     {r_cca_controlled['coefficients']['is_cca_commodity']:.3f} (p={r_cca_controlled['p_values']['is_cca_commodity']:.4f})")
print(f"    CCA non-commodity dummy: {r_cca_controlled['coefficients']['is_cca_non_commodity']:.3f} (p={r_cca_controlled['p_values']['is_cca_non_commodity']:.4f})")

# Test 6: Does Russia alone drive the CCA effect?
sub_no_russia = base_sample[base_sample['iso3'] != 'RUS'].copy()
r_no_russia = run_gls(sub_no_russia, 'ca_gdp', base_vars)
print(f"\n  Drop Russia only:")
print(_result_row('Drop RUS only', r_no_russia))

# Test 7: Drop Russia + Kazakhstan together
sub_no_rus_kaz = base_sample[~base_sample['iso3'].isin({'RUS', 'KAZ'})].copy()
r_no_rus_kaz = run_gls(sub_no_rus_kaz, 'ca_gdp', base_vars)
print(_result_row('Drop RUS + KAZ', r_no_rus_kaz))

# Descriptive: What do CCA CA balances and demographics look like?
print(f"\n\n--- CCA Country Characteristics ---")
print(f"  {'Country':<10s} {'Mean CA/GDP':>11s} {'Mean Z₁':>9s} {'Mean KAOPEN':>12s} {'N_obs':>6s}")
print("  " + "-" * 52)

for iso in sorted(cca_in_sample):
    # cca_in_sample is drawn from base_sample, so every country here has rows.
    sub = base_sample[base_sample['iso3'] == iso]
    tag = " [COM]" if iso in cca_commodity else ""
    print(f"  {iso + tag:<10s} {sub['ca_gdp'].mean():11.2f} {sub['Z_1'].mean():9.4f} {sub['kaopen'].mean():12.2f} {len(sub):6d}")

# Mean for CCA vs rest
cca_obs = base_sample[base_sample['iso3'].isin(cca_all)]
non_cca = base_sample[~base_sample['iso3'].isin(cca_all)]
print(f"\n  CCA mean:     CA/GDP={cca_obs['ca_gdp'].mean():6.2f}, Z₁={cca_obs['Z_1'].mean():.4f}, KAOPEN={cca_obs['kaopen'].mean():.2f}")
print(f"  Non-CCA mean: CA/GDP={non_cca['ca_gdp'].mean():6.2f}, Z₁={non_cca['Z_1'].mean():.4f}, KAOPEN={non_cca['kaopen'].mean():.2f}")

# Save CCA results
cca_rows = []
for name, r in [
    ("Full sample", r_full),
    ("Drop all CCA", r_no_cca),
    ("Drop CCA commodity", r_no_cca_com),
    ("Drop CCA non-commodity", r_no_cca_noncom),
    ("With CCA dummies", r_cca_controlled),
    ("Drop Russia only", r_no_russia),
    ("Drop Russia + Kazakhstan", r_no_rus_kaz),
]:
    row = {'sample': name, 'n_countries': r['n_countries'], 'n_obs': r['n_obs'], 'r_squared': r['r_squared']}
    for v in demo_vars:
        row[f'{v}_coef'] = r['coefficients'][v]
        row[f'{v}_pval'] = r['p_values'][v]
    cca_rows.append(row)

cca_df = pd.DataFrame(cca_rows)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
cca_df.to_csv(OUTPUT_DIR / 'cca_decomposition.csv', index=False)
print(f"\n  Saved to {OUTPUT_DIR / 'cca_decomposition.csv'}")

print("\nDone.")
